{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dcde8953-019b-4893-9250-c3be9f45ebb8",
   "metadata": {},
   "source": [
    "# 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "23653162-98c5-4754-8de5-6617ca31ce99",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import jieba\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import ndcg_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0184abe4-173c-4917-a83b-76916fb6c9bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = load_dataset('./hugging-face-dataset/C-MTEB/DuRetrieval')\n",
    "qrels = load_dataset('./hugging-face-dataset/C-MTEB/DuRetrieval-qrels')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "63bf646e-ea2f-4cac-992f-d88d64fa4dc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    corpus: Dataset({\n",
       "        features: ['id', 'text'],\n",
       "        num_rows: 100001\n",
       "    })\n",
       "    queries: Dataset({\n",
       "        features: ['id', 'text'],\n",
       "        num_rows: 2000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "454511fb-7926-4dd4-a1ba-bf3125735131",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    dev: Dataset({\n",
       "        features: ['qid', 'pid', 'score'],\n",
       "        num_rows: 9839\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qrels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "815f5bad-1bc7-42ea-a366-f93a4039b9c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "qrels = pd.DataFrame(qrels['dev'])\n",
    "qrels = qrels.groupby('qid')['pid'].apply(list).to_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80655592-2d47-4a28-89fa-48d5eafe4195",
   "metadata": {},
   "source": [
    "# BM25 (gensim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f7badbc1-52b4-45dd-8529-cca6388ca2cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.62295806 0.         0.        ]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'今天 很开心'"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from gensim.corpora import Dictionary\n",
    "from gensim.models import TfidfModel, OkapiBM25Model\n",
    "from gensim.similarities import SparseMatrixSimilarity\n",
    "import numpy as np\n",
    "corpus = [\n",
    "    \"今天 很开心\",\n",
    "    \"明天 也 很开心\",\n",
    "    \"Python is a programming language\"\n",
    "]\n",
    "corpus = [doc.lower().split() for doc in corpus]#[[\"Hello\", \"world\"], [\"bar\", \"bar\"], [\"foo\", \"bar\"]]\n",
    "dictionary = Dictionary(corpus)\n",
    "bm25_model = OkapiBM25Model(dictionary=dictionary)\n",
    "bm25_corpus = bm25_model[list(map(dictionary.doc2bow, corpus))]\n",
    "bm25_index = SparseMatrixSimilarity(bm25_corpus, num_docs=len(corpus), num_terms=len(dictionary),\n",
    "                                   normalize_queries=False, normalize_documents=False)\n",
    "query = [\"learn\", '今天']\n",
    "tfidf_model = TfidfModel(dictionary=dictionary, smartirs='bnn')  # Enforce binary weighting of queries\n",
    "tfidf_query = tfidf_model[dictionary.doc2bow(query)]\n",
    "similarities = bm25_index[tfidf_query]\n",
    "print(similarities)\n",
    "best_document = corpus[np.argmax(similarities)]\n",
    "\" \".join(best_document)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "71b9d370-8601-4689-b622-ab4d89b8270e",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    jieba.lcut(x['text']) for x in data['corpus']\n",
    "]\n",
    "dictionary = Dictionary(corpus)\n",
    "bm25_model = OkapiBM25Model(dictionary=dictionary)\n",
    "bm25_corpus = bm25_model[list(map(dictionary.doc2bow, corpus))]\n",
    "bm25_index = SparseMatrixSimilarity(bm25_corpus, num_docs=len(corpus), num_terms=len(dictionary),\n",
    "                                   normalize_queries=False, normalize_documents=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "1a6442fa-dde0-4330-adec-8444ca569e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://radimrehurek.com/gensim/models/tfidfmodel.html\n",
    "tfidf_model = TfidfModel(dictionary=dictionary, smartirs='afc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "f12a38cc-75f7-4bf0-b6da-81191f2bddc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 41.9 ms, sys: 0 ns, total: 41.9 ms\n",
      "Wall time: 39.4 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'该 经验 图片 、 文字 中 可能 存在 外站 链接 或 电话号码 等 , 请 注意 识别 , 谨防 上当受骗 ! 百度 经验 : jingyan . baidu . com 在 之前 win7 的 设置 中 大部分 操作 都 是 在 控制面板 中 完成 的 , 可是 在 新 的 win10 系统 中 却 发现 控制面板 很难 找 了 , 那么 在 哪里 能 找到 呢 ? 让 我 来 告诉 你 吧 ~ 百度 经验 : jingyan . baidu . com 百度 经验 : jingyan . baidu . com1   打开 开始菜单 , 点击 所有 应用   步骤 阅读   2   找到 windows 系统 , 点击 打开 下级菜单   步骤 阅读   3   在 当中 找到 控制面板 点击 打开 即可   步骤 阅读   END 百度 经验 : jingyan . baidu . com'"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "query = '请注意识别,谨防上当受骗!'\n",
    "query = jieba.lcut(query)\n",
    "tfidf_query = tfidf_model[dictionary.doc2bow(query)]\n",
    "similarities = bm25_index[tfidf_query]\n",
    "\n",
    "best_document = corpus[np.argmax(similarities)]\n",
    "\" \".join(best_document)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "24502d17-258f-42be-b4a7-460b39c6b5d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "qrels_dict = {x['qid']: x['pid'] for x in qrels['dev']}\n",
    "pid_array = np.array([x['id'] for x in data['corpus']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "1d60ff38-50fc-4db2-ad5b-376e7818cacf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [06:28<00:00,  5.14it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query = jieba.lcut(query_data['text'])\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    tfidf_query = tfidf_model[dictionary.doc2bow(query)]\n",
    "    similarities = np.array(bm25.get_scores(query))\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "485f69bd-b2a7-4f29-a589-75fcc2bf08a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7077681714845733, 0.3301973263950749)"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8e7fde7-5c08-41bc-bd93-90c0ad5d222a",
   "metadata": {},
   "source": [
    "# BM25 (rank_bm25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1a0b0957-f7a9-47bd-9c56-53f1e4cad7f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.        , 0.93729472, 0.        ])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from rank_bm25 import BM25Okapi\n",
    "\n",
    "corpus = [\n",
    "    \"Hello there good man!\",\n",
    "    \"It is quite windy in London\",\n",
    "    \"How is the weather today?\"\n",
    "]\n",
    "\n",
    "tokenized_corpus = [doc.split(\" \") for doc in corpus]\n",
    "\n",
    "bm25 = BM25Okapi(tokenized_corpus)\n",
    "\n",
    "query = \"windy London\"\n",
    "tokenized_query = query.split(\" \")\n",
    "bm25.get_scores(tokenized_query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "791fbec4-e32d-4005-8241-224caf871d58",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.539 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "corpus = [jieba.lcut(x['text']) for x in data['corpus']]\n",
    "bm25 = BM25Okapi(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "0a95206e-b3e4-40fa-a34b-07629e17d30f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [06:26<00:00,  5.17it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query = jieba.lcut(query_data['text'])\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = np.array(bm25.get_scores(query))\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d049daaa-e00e-4dee-964d-c27e1a2e4815",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7077681714845733, 0.3301973263950749)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d76d2b3-1da3-4a0d-9923-6002b5304dcf",
   "metadata": {},
   "source": [
    "# M3E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4dca0864-dc43-410e-839e-b890f99e9f78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sentence: * Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem\n",
      "Embedding: (512,)\n",
      "\n",
      "Sentence: * Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练\n",
      "Embedding: (512,)\n",
      "\n",
      "Sentence: * Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one\n",
      "Embedding: (512,)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "model = SentenceTransformer('./hugging-face-model/moka-ai/m3e-small/')\n",
    "\n",
    "#Our sentences we like to encode\n",
    "sentences = [\n",
    "    '* Moka 此文本嵌入模型由 MokaAI 训练并开源，训练脚本使用 uniem',\n",
    "    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',\n",
    "    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算，异质文本检索等功能，未来还会支持代码检索，ALL in one'\n",
    "]\n",
    "\n",
    "#Sentences are encoded by calling model.encode()\n",
    "embeddings = model.encode(sentences)\n",
    "\n",
    "#Print the embeddings\n",
    "for sentence, embedding in zip(sentences, embeddings):\n",
    "    print(\"Sentence:\", sentence)\n",
    "    print(\"Embedding:\", embedding.shape)\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "69c04b3a-1586-41c7-9f97-64fbfe961a83",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [x['text'] for x in data['corpus']][:]\n",
    "embeddings = model.encode(corpus, normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "81ce3723-6408-4345-a9fd-2e2c6c705fcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:32<00:00, 21.68it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "99476225-c09f-4341-94e2-c8ae48f497fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.7614511933399406, 0.27839499502211373)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcfa4b71-e073-4774-a7ba-3f69a01d5954",
   "metadata": {},
   "source": [
    "# BGE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2a6dd465-e645-49bf-83b9-a51bbb9b01d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.86921406 0.8665448 ]\n",
      " [0.899487   0.8784326 ]]\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "sentences_1 = [\"样例数据-1\", \"样例数据-2\"]\n",
    "sentences_2 = [\"样例数据-3\", \"样例数据-4\"]\n",
    "model = SentenceTransformer('./hugging-face-model/BAAI/bge-small-zh-v1.5/')\n",
    "embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)\n",
    "embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)\n",
    "similarity = embeddings_1 @ embeddings_2.T\n",
    "print(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b5085560-3759-40f8-800a-e503c2ae9e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [x['text'] for x in data['corpus']][:]\n",
    "embeddings = model.encode(corpus, normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c75a4c0e-138b-423e-b4b9-f94ae4678f73",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:33<00:00, 21.35it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cf693eef-edce-4caf-a490-9cefa3611743",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.8629532887126161, 0.221727574464936)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a82bba0b-1bcb-4a43-a554-cf025367d34b",
   "metadata": {},
   "source": [
    "# BGE with ReRank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f23ebcb7-9b74-46f4-8e14-dd020ae83910",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-6.0319,  8.9878], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('./hugging-face-model/BAAI/bge-reranker-base/')\n",
    "rerank_model = AutoModelForSequenceClassification.from_pretrained('./hugging-face-model/BAAI/bge-reranker-base/')\n",
    "rerank_model.cuda()\n",
    "rerank_model.eval()\n",
    "\n",
    "pairs = [['我很开心', '我很沮丧'], ['我很开心', '我很快乐']]\n",
    "with torch.no_grad():\n",
    "    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
    "    inputs = {key: inputs[key].cuda() for key in inputs.keys()}\n",
    "    scores = rerank_model(**inputs, return_dict=True).logits.view(-1, ).float()\n",
    "    print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "437e3a36-e27a-4153-afc4-7cd47a91c597",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [27:28<00:00,  1.21it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "\n",
    "    pairs = []\n",
    "    for xi in top_results:\n",
    "        pairs.append([query_data['text'], corpus[xi]])\n",
    "\n",
    "    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
    "    with torch.no_grad():\n",
    "        inputs = {key: inputs[key].cuda() for key in inputs.keys()}\n",
    "        scores = rerank_model(**inputs, return_dict=True).logits.view(-1, ).float()\n",
    "\n",
    "    top_results = np.array(top_results)\n",
    "    rerank_result = top_results[scores.cpu().data.numpy().argsort()[::-1]]\n",
    "    \n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in rerank_result]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "8faef74b-4a04-454f-87fd-bd924edacbf7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9160629142553329, 0.19732814717552222)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ae10af4-6784-4c60-94df-8364749b9db4",
   "metadata": {},
   "source": [
    "# BCEmbedding（有道）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "901efb29-e5df-48ea-8b0d-d822bc627fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# list of sentences\n",
    "sentences = ['sentence_0', 'sentence_1', ...]\n",
    "\n",
    "# init embedding model\n",
    "model = SentenceTransformer(\"./hugging-face-model/maidalun1020/bce-embedding-base_v1\")\n",
    "\n",
    "# set max_length to 512 to avoid an error.\n",
    "model.max_seq_length = 512\n",
    "\n",
    "# extract embeddings\n",
    "embeddings = model.encode(sentences, normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "57eb6eb9-6fa4-4613-9366-cc569b141ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [x['text'] for x in data['corpus']][:]\n",
    "embeddings = model.encode(corpus, normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "9527618f-f679-4941-98f3-85a58745cc3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [02:06<00:00, 15.76it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "00a465eb-3971-407b-b01a-5552c4d089d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.8436841563775636, 0.22071295950607342)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e76984e-b551-4d86-ac64-2b4292c1d8ab",
   "metadata": {},
   "source": [
    "# BCEmbedding with ReRank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e9c90e1e-6dd6-4b8f-a593-20f5cffe4c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import CrossEncoder\n",
    "\n",
    "# init reranker model\n",
    "rerank_model = CrossEncoder('./hugging-face-model/maidalun1020/bce-reranker-base_v1', max_length=512)\n",
    "\n",
    "pairs = [['我很开心', '我很沮丧'], ['我很开心', '我很快乐']]\n",
    "scores = rerank_model.predict(pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c635eda9-b57e-4a39-b412-04b60541428e",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_feats = model.encode([x['text'] for x in data['queries']], normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a05dde0c-6a08-4bbd-8cfc-5adbbec482de",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.17447105, 0.19908555, 0.15706639, ..., 0.25002584, 0.2793538 ,\n",
       "       0.28668842], dtype=float32)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_feat.dot(embeddings.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "231bcd1a-0a23-4ef8-bf9d-9b245e215d84",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [27:58<00:00,  1.19it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "\n",
    "    pairs = []\n",
    "    for xi in top_results:\n",
    "        pairs.append([query_data['text'], corpus[xi]])\n",
    "\n",
    "    with torch.no_grad():\n",
    "        scores = rerank_model.predict(pairs)\n",
    "\n",
    "    top_results = np.array(top_results)\n",
    "    rerank_result = top_results[scores.argsort()[::-1]]\n",
    "    \n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in rerank_result]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "f9734497-8143-400e-a620-1944fb104d92",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.46059702380952383, 0.3968857660785824)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b27fe8e-eef2-4f5b-b9bb-eb4d73785838",
   "metadata": {},
   "source": [
    "# GTE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "e13fa1de-7926-41c1-9586-78c5945f5a56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.962006   0.95747596]\n",
      " [0.95021254 0.94485617]]\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "sentences_1 = [\"样例数据-1\", \"样例数据-2\"]\n",
    "sentences_2 = [\"样例数据-3\", \"样例数据-4\"]\n",
    "model = SentenceTransformer('./hugging-face-model/thenlper/gte-small-zh')\n",
    "embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)\n",
    "embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)\n",
    "similarity = embeddings_1 @ embeddings_2.T\n",
    "print(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "9d6e9613-95e0-4ff6-ad62-4ff4eef2528b",
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [x['text'] for x in data['corpus']][:]\n",
    "embeddings = model.encode(corpus, normalize_embeddings=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "c9798780-c981-4495-9229-58b967214a6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2000/2000 [01:58<00:00, 16.90it/s]\n"
     ]
    }
   ],
   "source": [
    "topn = 30\n",
    "query_ndcg_score = []\n",
    "for query_data in tqdm(data['queries']):\n",
    "    query_feat = model.encode(query_data['text'], normalize_embeddings=True)\n",
    "    query_qid = query_data['id']\n",
    "    query_pids = qrels[query_qid]\n",
    "    \n",
    "    similarities = query_feat.dot(embeddings.T)\n",
    "    top_results = similarities.argsort()[::-1][:topn]\n",
    "    top_results = [data['corpus'][int(x)]['id'] for x in top_results]\n",
    "    true_relevance  = [[x in query_pids for x in top_results]]\n",
    "    scores = [list(similarities[similarities.argsort()[::-1][:topn]])]\n",
    "\n",
    "    query_ndcg_score.append(ndcg_score(true_relevance, scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "401d4f64-970e-4ff3-a0f6-b50662f259c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.8563148126485296, 0.21721973465173638)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(query_ndcg_score), np.std(query_ndcg_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "508a8c0a-6e85-4e23-9c2a-96177d6dc4cb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
